from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

import numpy as np

# Affichage de la table
def AfficheTable(S1, S2, D):
    n = len(S1)
    m = len(S2)

    # Créer une matrice RGB avec 2 lignes et 2 colonnes supplémentaires pour les en-têtes
    mat = np.zeros((n+3, m+3, 3))

    # Couleur des cases d'en-tête (bleu clair)
    header_color = [0.85, 0.85, 1.0]

    # Remplir les en-têtes
    for j in range(m+3):
        mat[0][j] = header_color  # Ligne des indices j
        mat[1][j] = header_color  # Ligne des caractères S2
    for i in range(n+3):
        mat[i][0] = header_color  # Colonne des indices i
        mat[i][1] = header_color  # Colonne des caractères S1

    # Remplir les données
    for i in range(n+1):
        for s in range(m+1):
            if (i, s) in D:
                mat[i+2][s+2] = [0.2, 0.7, 0.3]  # Vert
            else:
                mat[i+2][s+2] = [0.85, 0.85, 0.85]  # Gris clair

    plt.close('all')
    fig, ax = plt.subplots(figsize=((m+3)*0.4, (n+3)*0.4))
    ax.imshow(mat)

    # Afficher les valeurs dans chaque case de données
    for i in range(n+1):
        for s in range(m+1):
            if (i, s) in D:
                valeur = D[(i, s)]
                ax.text(s+2, i+2, str(int(valeur)), ha='center', va='center',
                        color='white', fontsize=9, fontweight='bold')

    # Afficher les indices j (ligne 0)
    for j in range(m+1):
        ax.text(j+2, 0, str(j), ha='center', va='center', color='black', fontsize=9)

    # Afficher les caractères S2 (ligne 1)
    for j in range(m+1):
        char = ' ' if j == 0 else S2[j-1]
        ax.text(j+2, 1, char, ha='center', va='center', color='red', fontsize=9, fontweight='bold')

    # Afficher les indices i (colonne 0)
    for i in range(n+1):
        ax.text(0, i+2, str(i), ha='center', va='center', color='black', fontsize=9)

    # Afficher les caractères S1 (colonne 1)
    for i in range(n+1):
        char = ' ' if i == 0 else S1[i-1]
        ax.text(1, i+2, char, ha='center', va='center', color='red', fontsize=9, fontweight='bold')

    # Labels des en-têtes (coin supérieur gauche)
    ax.text(0, 0, '', ha='center', va='center', color='black', fontsize=9, fontweight='bold')
    ax.text(1, 0, 'j', ha='center', va='center', color='black', fontsize=9, fontweight='bold')
    ax.text(0, 1, 'i', ha='center', va='center', color='black', fontsize=9, fontweight='bold')
    ax.text(1, 1, '', ha='center', va='center', color='black', fontsize=9, fontweight='bold')

    # Légendes des axes (en dehors de la matrice)
    ax.text((m+3)/2, -0.8, 'Préfixe S2', ha='center', va='center', color='black', fontsize=11, )
    ax.text(-0.8, (n+3)/2, 'Préfixe S1', ha='center', va='center', color='black', fontsize=11, rotation=90)

    # Quadrillage complet sur la matrice
    # Lignes horizontales
    for i in range(n+4):
        ax.plot([-0.5, m+2.5], [i-0.5, i-0.5], color='black', linewidth=1)
    # Lignes verticales
    for j in range(m+4):
        ax.plot([j-0.5, j-0.5], [-0.5, n+2.5], color='black', linewidth=1)

    # Cacher les axes et les spines (bordures du graphique)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Ajuster les limites pour voir les légendes
    ax.set_xlim(-1.5, m+2.5)
    ax.set_ylim(n+2.5, -1.5)

    plt.title('Table de programmation dynamique', pad=10)
    plt.tight_layout()
    plt.show()


# Séquences ADN à comparer
seq1 = "ACGTAC"
seq2 = "AGTCAT"
L = {}  # Table de mémoïsation

reference = "ACGTAC"
candidates = ["ACGT", "CGTA", "GTAC", "TACG"]


##########################
# Mode bottom-up
##########################

def initialise_cas_de_base(S1, S2, L):
    n = len(S1)
    m = len(S2)

    for i in range(n+1):
        L[(i,0)] = 0
    for j in range(m+1):
        L[(0,j)] = 0

    return L

def remplir_table(S1, S2, L):
    n = len(S1)
    m = len(S2)

    for i in range(1,n+1):
        for j in range(1,m+1):
            # Cas du match
            if S1[i-1]==S2[j-1]:
                L[(i,j)] = L[(i-1,j-1)] + 1
            # Cas 2a et 2b
            else:
                V1 = L[(i-1,j)]
                V2 = L[(i,j-1)]
                L[(i,j)] = max(V1,V2)
    return L

# Question 5 : Complexité
# - Nombre de sous-problèmes : (n+1) * (m+1) = O(n*m)
# - Complexité temporelle : O(n*m) car chaque sous-problème est calculé en O(1)
# - Complexité spatiale : O(n*m) pour stocker la table
def lcs_bottomup(S1, S2):
    n = len(S1)
    m = len(S2)

    L = {}
    L = initialise_cas_de_base(S1,S2,L)
    L = remplir_table(S1,S2,L)

    return L[(n,m)]

def comparer_sequences(seq_reference, liste_sequences):
    seq_selectionnes = []
    distances = []

    d_min = -np.inf
    for seq in liste_sequences:
        dist = lcs_bottomup(seq_reference,seq)
        distances.append(dist)
        if dist > d_min:
            d_min = dist

    offset = 0
    for i in range(distances.count(d_min)):
        pos = distances.index(d_min,offset)
        seq_selectionnes.append(liste_sequences[pos])
        offset = pos + 1

    return seq_selectionnes, d_min


print(initialise_cas_de_base(seq1,seq2,L))
L = remplir_table(seq1,seq2,L)
AfficheTable(seq1,seq2,L)
print(lcs_bottomup(seq1,seq2))



############################
# Mode Top-down
############################

L = {}

# Question 2 : Comparaison bottom-up vs top-down
# L'approche top-down optimisée calcule moins de sous-problèmes car :
# - En cas de match, seul le sous-problème diagonal (i-1, j-1) est exploré
# - Les sous-problèmes (i-1, j) et (i, j-1) ne sont PAS calculés dans ce cas
# - Seuls les sous-problèmes réellement nécessaires sont calculés
#
# Cette différence est plus marquée quand les deux chaînes sont très similaires
# (beaucoup de matchs), car de nombreux sous-problèmes sont alors évités.

# Question 3 : Complexité de l'algorithme top-down
# - Complexité temporelle (pire cas) : O(n*m)
#   Car au maximum (n+1)*(m+1) sous-problèmes distincts, chacun calculé en O(1)
# - Complexité spatiale :
#   - Dictionnaire : O(n*m) dans le pire cas
#   - Pile d'appels : O(n+m) (profondeur maximale de récursion)
#   - Total : O(n*m) (le dictionnaire domine)


def rec_lcs(S1, S2):
    n = len(S1)
    m = len(S2)
    def f_rec(i, j):
        # Utilise la mémoïsation
        if (i,j) in L:
            return L[(i,j)]
        # Cas de base
        if i == 0 or j == 0:
            L[(i,j)] = 0
            return L[(i,j)]
        # Cas du match
        if S1[i-1] == S2[j-1]:
            L[(i,j)] = f_rec(i-1,j-1) + 1
            return L[(i,j)]
        # Sinon, calcule les deux possibilités
        else:
            V1 = f_rec(i-1,j)
            V2 = f_rec(i,j-1)
            # Mémoïse et retourne la valeur optimale
            L[(i,j)] = max(V1,V2)
            return L[(i,j)]
    longueur = f_rec(n,m)
    return longueur

def determiner_choix(S1, S2, L, i, j):
    n = len(S1)
    m = len(S2)

    if i > 0 and j > 0:
        # Cas n°1
        if S1[i-1] == S2[j-1] and L[(i,j)] == L[(i-1,j-1)] + 1:
            return ("GARDE " + S1[i-1], i-1, j-1)

        # Max des cas 2a et 2b
        elif L[(i-1,j)] >= L[(i,j-1)]:
            return ("LAISSE " + S1[i-1],i-1,j)
        else:
            return ("LAISSE " + S1[i-1],i,j-1)

# Question 3 : Complexité de la reconstruction
# - On parcourt au plus (n + m) étapes (à chaque étape, i ou j diminue d'au moins 1)
# - Chaque étape fait un travail en O(1)
# - Complexité temporelle : O(n + m)

# Question 4 : Complexité finale
# - Calcul des valeurs optimales : O(n*m)
# - Reconstruction : O(n + m)
# - Total : O(n*m) + O(n + m) = O(n*m)

def reconstruire_lcs(S1, S2, L):
    longueur = rec_lcs(S1,S2)

    seq = []
    i = len(S1)
    j = len(S2)

    while i > 0 and j > 0:
        op, i, j = determiner_choix(S1,S2,L,i,j)
        if "GARDE" in op:
            seq.append(S1[i])

    seq_txt = ""
    for i in range(len(seq)-1,-1,-1):
        seq_txt = seq_txt + seq[i]
    return seq_txt




rec_lcs(seq1,seq2)
AfficheTable(seq1,seq2,L)

print(determiner_choix(seq1,seq2,L,6,6))
print(determiner_choix(seq1,seq2,L,5,5))
print(determiner_choix(seq1,seq2,L,4,3))

print(reconstruire_lcs(seq1,seq2,L))
























